## Monte Carlo estimate of Pr(H = H_1 | data) for two-arm count data.
##
## Arguments:
##   Y_t, Y_c       : counts added to the Gamma shape of the treatment /
##                    control arm (shape = a + Y)
##   n              : value added to the Gamma rate of both arms (rate = b + n)
##   theta_star     : threshold for the difference (treatment - control)
##   a1, b1, a2, b2 : Gamma shape/rate hyperparameters for treatment (a1, b1)
##                    and control (a2, b2)
##   q              : weight on H_1 in the final combination (default 0.5)
##   sim            : number of Monte Carlo draws per Gamma distribution
##
## Returns a single number:
##   (q/C1)*p / ((1-q)/C0 + (q/C1 - (1-q)/C0)*p)
## where p  = fraction of updated draws with difference > theta_star,
##       C1 = the same fraction under draws from Gamma(a1, b1) - Gamma(a2, b2),
##       C0 = 1 - C1.
## The result depends on the RNG state; call set.seed() for reproducibility.
post_prob_H1_cnt_sim <- function(Y_t, Y_c, n, theta_star, a1, b1, a2, b2, q = 0.5, sim = 10000){

  # Draws for the two arms from Gamma(shape = a + Y, rate = b + n).
  # Treatment is sampled before control so the random stream is consumed
  # in a fixed order.
  treat_draws <- rgamma(sim, a1 + Y_t, rate = b1 + n)
  ctrl_draws <- rgamma(sim, a2 + Y_c, rate = b2 + n)

  # p: share of draws whose difference exceeds theta_star
  exceed_frac <- sum((treat_draws - ctrl_draws) > theta_star) / sim

  # C1 / C0: the same exceedance share under the hyperparameters alone
  prior_treat <- rgamma(sim, a1, rate = b1)
  prior_ctrl <- rgamma(sim, a2, rate = b2)
  C1 <- sum((prior_treat - prior_ctrl) > theta_star) / sim
  C0 <- 1 - C1

  # With q = 0.5 the general expression reduces to this form
  # (common factor cancelled); it is kept as a separate branch.
  if (q == 0.5) {
    return((C0 * exceed_frac) / (C1 + (C0 - C1) * exceed_frac))
  }

  # General q: weight each hypothesis by q / C1 and (1 - q) / C0
  w1 <- q / C1
  w0 <- (1 - q) / C0
  (w1 * exceed_frac) / (w0 + (w1 - w0) * exceed_frac)
}

# -------------------------------------------------------------------------
# Quantity used to check the count-data parameters against Proposition 3 --
# -------------------------------------------------------------------------
# Parameters:
#   e: evidence; y_bar0: reference for control (accepted but not used here);
#   a, b: hyperparameters
# Returns:
#   Gamma(a + 2e) * Gamma(a) / Gamma(a + e)^2
#     * ((b + 1)^2 / (b * (b + 2)))^a * ((b + 1) / (b + 2))^e,
#   evaluated on the log scale and exponentiated at the end.
#   NOTE(review): the function returns the value only; the comparison that
#   Proposition 3 requires is left to the caller -- confirm the threshold.
check_param <- function(e, y_bar0, a, b){
  # Logs of the three rate-related quantities
  log_b <- log(b)
  log_b1 <- log(b + 1)
  log_b2 <- log(b + 2)

  # log of Gamma(a + 2e) * Gamma(a) / Gamma(a + e)^2
  log_gamma_ratio <- lgamma(a + 2 * e) + lgamma(a) - 2 * lgamma(a + e)
  # contribution scaled by a
  log_part_a <- a * (2 * log_b1 - log_b - log_b2)
  # contribution scaled by e
  log_part_e <- e * (log_b1 - log_b2)

  exp(log_gamma_ratio + log_part_a + log_part_e)
}

## Function to find the optimal theta_0.
## Returns the larger root of the quadratic 2*x^2 + B*x + C = 0 with
##   B = 2*theta_star - (2*y_bar0 + e)   and   C = -y_bar0*theta_star.
## Arguments:
##   y_bar0: reference rate for control; e: evidence;
##   theta_star: effect-size threshold
find_theta0_count <- function(y_bar0, e, theta_star){

  lin_coef <- 2 * theta_star - (2 * y_bar0 + e)
  const_coef <- -y_bar0 * theta_star

  # Quadratic formula with leading coefficient 2; `sgn` selects the branch
  root_for <- function(sgn) {
    (-lin_coef + sgn * sqrt(lin_coef^2 - 4 * 2 * const_coef)) / (2 * 2)
  }

  max(root_for(1), root_for(-1))
}

## Evaluate the cal{B} term at sample size n for a given y_bar0, e and
## theta_star; callers check whether the returned value is greater than 0.
## Arguments:
##   e: evidence; y_bar0: reference rate for control;
##   theta_star: effect-size threshold; a1, b1, a2, b2: hyperparameters;
##   n: sample size at which the term is evaluated
## Side effects: prints the theta_0 value from find_theta0_count() and the
## resulting term before returning the term.
n_max_gt_0_gamma <- function(e, y_bar0, theta_star, a1, b1, a2, b2, n){

  # lgamma(shape) - shape * log(rate), used for both arms at n and n + 1
  log_kernel <- function(shape, rate) lgamma(shape) - shape * log(rate)

  y_bar1 <- y_bar0 + e
  lk_trt_next <- log_kernel(a1 + (n + 1) * y_bar1, b1 + n + 1)
  lk_ctl_next <- log_kernel(a2 + (n + 1) * y_bar0, b2 + n + 1)
  lk_trt_curr <- log_kernel(a1 + n * y_bar1, b1 + n)
  lk_ctl_curr <- log_kernel(a2 + n * y_bar0, b2 + n)

  theta0_hat <- find_theta0_count(y_bar0, e, theta_star)
  print(theta0_hat)

  # First term: change in the summed kernels when going from n to n + 1
  log_lhs <- lk_trt_next + lk_ctl_next - lk_trt_curr - lk_ctl_curr
  # Second term: evaluated at theta0_hat and theta0_hat + theta_star
  log_rhs <- y_bar1 * log(theta0_hat + theta_star) + y_bar0 * log(theta0_hat) -
    (2 * theta0_hat + theta_star)

  B_2 <- exp(log_lhs) - exp(log_rhs)
  print(B_2)
  B_2
}

## Step function gamma(n) for find_n_min.
##
## With independent
##   theta1 ~ Gamma(shape = a1 + n*(y_bar0 + e), rate = b1 + n)
##   theta2 ~ Gamma(shape = a2 + n*y_bar0,       rate = b2 + n)
## this returns Pr(theta1 - theta2 <= theta_star), split as
##   Pr(theta1 <= theta_star)
##     + integral over x > theta_star of f1(x) * Pr(theta2 >= x - theta_star) dx.
##
## Arguments:
##   e: evidence; y_bar0: reference rate for control;
##   theta_star: effect-size threshold; a1, b1, a2, b2: hyperparameters;
##   n: sample size
## Returns: a single numeric value (a probability, up to integration error).
gamma_func_step <- function(e, y_bar0, theta_star, a1, b1, a2, b2, n){
  # Updated Gamma parameters at sample size n
  alpha1 <- a1 + n * (y_bar0 + e)
  beta1 <- b1 + n
  alpha2 <- a2 + n * y_bar0
  beta2 <- b2 + n

  # theta2 >= 0, so theta1 <= theta_star already implies the event
  term1 <- pgamma(theta_star, shape = alpha1, rate = beta1)

  # Remaining mass: theta1 = x > theta_star and theta2 >= x - theta_star.
  # Both factors must use the updated parameters; integrating the
  # hyperparameter densities instead makes this term constant in n and
  # lets term1 + term2 exceed 1.
  term2 <- integrate(
    f = function(x){
      dgamma(x, shape = alpha1, rate = beta1) *
        pgamma(x - theta_star, shape = alpha2, rate = beta2, lower.tail = FALSE)
    },
    lower = theta_star, upper = Inf, rel.tol = 1e-8, abs.tol = 1e-12
  )$value

  term1 + term2
}

# -------------------------------------------------------------------------
# BESS Algorithm 4: Compute the n_min for two-arm trials with count data --
# -------------------------------------------------------------------------
# Parameters:
#   theta_star: minimum clinically meaningful effect size;
#   e: evidence; y_bar0: reference rate on control
#   a1, b1, a2, b2: hyperparameters;
#   n_max: maximum candidate sample size;
# Returns:
#   The smallest n in 1..n_max such that gamma(m) - gamma(m + 1) >= 0 for
#   every m in n..n_max, where gamma() is gamma_func_step().
# Errors:
#   Stops when the check at n_max fails, when every difference is negative,
#   or when no n satisfies the condition above (last difference negative).
find_min_n_BESS2p <- function(theta_star, e, y_bar0, 
                              a1, b1, a2, b2, n_max) {
  
  # 1. sanity check at n_max
  if (n_max_gt_0_gamma(e, y_bar0, theta_star,
                       a1, b1, a2, b2, n_max) < 0) {
    stop("Please increase n_max.")
  }
  
  # 2. gamma(n) for n = 1, ..., n_max + 1; each value is computed once and
  #    reused for both differences it takes part in
  gamma_vals <- vapply(seq_len(n_max + 1),
    function(n) {
      gamma_func_step(e, y_bar0, theta_star,
                      a1, b1, a2, b2, n)
    }, numeric(1))
  
  # 3. differences gamma(n) - gamma(n + 1) for n = 1, ..., n_max
  diffs <- gamma_vals[-length(gamma_vals)] - gamma_vals[-1]
  
  # 4. if every difference is negative, no candidate can qualify
  if (max(diffs) < 0) {
    stop("Please increase nmax.")
  }
  
  # 5. right_min[n] = min(diffs[n:n_max]); the first n with right_min >= 0
  #    is the point from which all remaining differences are non-negative
  right_min <- rev(cummin(rev(diffs)))
  n_min <- which(right_min >= 0)[1]
  
  # 6. a negative last difference leaves no qualifying n; fail loudly
  #    instead of returning NA as a sample size
  if (is.na(n_min)) {
    stop("Please increase n_max.")
  }
  n_min
}

# ------------------------------------------------------------------------
# Unit-Testing BESS Algo 4 -----------------------------------------------
# ------------------------------------------------------------------------
#e <- 0.1
#y_bar0 <- 0.3
#theta_star <- 0.05
#find_min_n_BESS2p(theta_star, e, y_bar0, 1, 2, 1, 2, 700)

# -------------------------------------------------------------------------
# BESS Algorithm 2': Find sample size for two-arm trials count data -------
# -------------------------------------------------------------------------
# Parameters:
#   theta_star: minimum clinically meaningful effect size;
#   c: confidence level; e: evidence; ybar0: the reference rate for control;
#   a1 = 1, b1 = 2, a2 = 1, b2 = 2: hyperparameters;
#   n_min = 5, n_max = 100: minimum and maximum candidate sample sizes;
#   q = 0.5: prior prob of H_1; sim = 10000: simulation number;
# Returns a list with:
#   n           : first candidate n whose simulated Pr(H_1) is >= c, or the
#                 last candidate when none reaches c
#   yd          : yd_vec entry at the selected n
#   prob_h1_vec : simulated Pr(H_1) for every candidate n
#   yd_vec, y1_vec, y0_vec : per-candidate values used in the simulation
# Results depend on the RNG state; call set.seed() for reproducibility.
BESS_count_data <- function(theta_star, c, e, ybar0,
                            a1 = 1, b1 = 2, a2 = 1, b2 = 2,
                            n_min = 5, n_max = 100, q = 0.5, sim = 10000) {
  
  # Candidate sample sizes
  n_seq <- seq(n_min, n_max)
  
  # Per candidate: control value, difference value, and their sum
  y0_vec <- round(ybar0 * n_seq)
  yd_vec <- floor(e * n_seq)
  y1_vec <- y0_vec + yd_vec
  
  # Simulated Pr(H_1) per candidate, evaluated in increasing index order so
  # the random stream is consumed in the same sequence on every run
  xi_vec <- numeric(length(n_seq))
  for (k in seq_along(n_seq)) {
    xi_vec[k] <- post_prob_H1_cnt_sim(
      Y_t = y1_vec[k], Y_c = y0_vec[k], n = n_seq[k],
      theta_star = theta_star, a1 = a1, b1 = b1, a2 = a2, b2 = b2,
      q = q, sim = sim)
  }
  
  # First candidate reaching the level c; fall back to the last candidate
  hits <- which(xi_vec >= c)
  hit_idx <- if (length(hits) > 0) hits[1] else length(n_seq)
  
  list(n = n_seq[hit_idx], yd = yd_vec[hit_idx], prob_h1_vec = xi_vec,
       yd_vec = yd_vec, y1_vec = y1_vec, y0_vec = y0_vec)
}

# ------------------------------------------------------------------------
# Unit-Testing BESS Algo 2p ----------------------------------------------
# ------------------------------------------------------------------------
#set.seed(12345)
#e <- 0.1
#y_bar0 <- 0.3
#theta_star <- 0.05
#c <- 0.8
#BESS_count_data(theta_star, c, e, y_bar0, n_max = 700)
